library(tidyverse)
library(ggpubr)
library(hrbrthemes)
library(ggrepel)
library(ggimage)
library(grid)
library(kernlab) 
library(quanteda)
library(caret)
library(MLmetrics)


# Count how many rows carry each value of a label column.
#
# Args:
#   df            data frame / tibble containing the labels.
#   label_column  name of the label column, given as a string.
#
# Returns: a tibble with the label column plus a `count` column, ordered from
# the most to the least frequent label.
get_label_counts <- function(df, label_column) {
  per_label <- df |>
    group_by(across(all_of(label_column))) |>
    summarise(count = n())
  arrange(per_label, desc(count))
}

################################################################################
################################################################################
################################################################################

# Build a faceted confusion-matrix heat map (one facet per prompting strategy)
# for a single method / label-type combination.
#
# Args:
#   method_input      value of the `method` column to keep.
#   label_type_input  value of the `label_type` column to keep.
#   data              prediction table; defaults to the global
#                     `df_pred_prompting`. Columns used: method, label_type,
#                     prompting_strategy, expected_label, dataset_label.
#   verbose           print the filtered rows before plotting (default TRUE,
#                     as before).
#
# Labels are shortened to their first letter, and the letters d/e/n are folded
# onto a/c/o so both label vocabularies share one axis. Rows are counted AFTER
# folding, so each (strategy, actual, predicted) cell is a single tile. Cells
# are green where Actual == Predicted (after folding) and red otherwise.
#
# Returns: a ggplot object.
create_plot <- function(method_input, label_type_input,
                        data = df_pred_prompting, verbose = TRUE) {
  # Map a label to the shared one-letter axis alphabet.
  fold_label <- function(x) {
    first <- substr(x, 1, 1)
    case_when(
      first == "d" ~ "a",
      first == "e" ~ "c",
      first == "n" ~ "o",
      TRUE ~ first
    )
  }

  df_filtered <- data |>
    filter(method == !!method_input, label_type == !!label_type_input) |>
    mutate(dataset_label = pick_first(dataset_label))

  if (verbose) {
    print(df_filtered)
  }

  conf_matrix_prompting <- df_filtered |>
    transmute(
      Group = prompting_strategy,
      Actual = fold_label(dataset_label),
      Predicted = fold_label(expected_label)
    ) |>
    # Count after folding: labels that collapse onto the same letter must be
    # merged into one tile, not drawn as overlapping tiles.
    count(Group, Actual, Predicted, name = "Count") |>
    mutate(Color = ifelse(Actual == Predicted, "green", "red"))

  ggplot(conf_matrix_prompting,
         aes(x = Predicted, y = Actual, fill = Color, alpha = Count)) +
    geom_tile(color = "white") +
    coord_equal() +
    geom_text(aes(label = Count), vjust = 0.5, fontface = "bold", alpha = 1) +
    # Named values: unnamed values are matched to the levels present in the
    # data, so a plot containing only "red" cells would be drawn green.
    scale_fill_manual(values = c(green = "green", red = "red")) +
    scale_alpha(range = c(0.4, 1)) +
    facet_wrap(~ Group, nrow = 2) +
    guides(fill = "none") +
    theme_minimal() +
    labs(title = paste(method_input, "-", label_type_input)) +
    theme(plot.title = element_text(size = 10, face = "bold", hjust = 0.5),
          legend.position = "none")
}


################################################################################
################################################################################
################################################################################

# Tune SVM hyper-parameters once, then evaluate the tuned setting on several
# seeds.
#
# For each kernel in `kernels`:
#   1. Build a tuning grid and call train_and_evaluate_model_v2() once with
#      `tuning_seed`. The grid is C only for svmLinear. Otherwise it is
#      sigma x C, with sigma candidates from kernlab::sigest() when
#      `sigma_grid` is NULL.
#   2. Call train_and_evaluate_model_v2() again at the single best setting for
#      every seed in `eval_seeds`, and summarise the macro-F1 values.
#   3. If `save_models` is TRUE, refit on all rows of `df` at the tuned
#      setting (once per eval seed) and save model + feature spec as .rds.
#
# Args:
#   df            data frame with a `text` column and `label_column`.
#   label_column  name of the label column.
#   label_levels  factor levels used for the labels.
#   seeds         candidate seeds; supplies the defaults of the next two.
#   tuning_seed   seed for the single tuning run (default: seeds[1]).
#   eval_seeds    seeds for evaluation / saving (default: first 3 of seeds).
#   kernels       caret methods to run (e.g. "svmLinear", "svmRadial").
#   C_grid        cost values to tune over.
#   sigma_grid    RBF sigma values; NULL = estimate via kernlab::sigest().
#   save_models   write refitted models + feature specs to disk?
#   dataset_name  path component for saved files (required when saving).
#   model_base, feature_base  root directories for saved models / specs.
#
# Returns: a list named by kernel, each element holding
#   aggregate         tuning seed, fixed params, eval seeds, mean/sd macro-F1
#   per_seed_macro    one row per eval seed (seed, macro_f1, C, sigma)
#   best_from_tuning  best_tune of the tuning run, as a data frame
train_and_evaluate_model_one_tune_then_eval <- function(
    df,
    label_column,
    label_levels,
    # seeds
    seeds        = c(42, 123, 999, 2024, 2025),
    tuning_seed  = NULL,                 # if NULL, uses seeds[1]
    eval_seeds   = NULL,                 # if NULL, uses first 3 of 'seeds'
    # models & grids
    kernels      = c("svmLinear","svmRadial"),
    C_grid       = 2^seq(-5, 5, by = 1),
    sigma_grid   = NULL,                 # NULL = auto via sigest (using tuning_seed)
    # saving
    save_models  = FALSE,
    dataset_name = NULL,
    model_base   = "models/models",
    feature_base = "models/features"
) {
  # --- helpers: artifact paths, saving, training features ------------------
  ensure_dir <- function(path) dir.create(dirname(path), recursive = TRUE, showWarnings = FALSE)
  fmtC   <- function(x) format(x, scientific = FALSE, trim = TRUE)
  fmtSig <- function(x) format(x, scientific = TRUE, trim = TRUE)
  model_path_for <- function(km, ds, seed, C, sigma = NULL) {
    sub <- if (km == "svmLinear") {
      sprintf("%s/%s/seed_%s/C_%s/model.rds", km, ds, seed, fmtC(C))
    } else {
      sprintf("%s/%s/seed_%s/sigma_%s/C_%s/model.rds", km, ds, seed, fmtSig(sigma), fmtC(C))
    }
    file.path(model_base, sub)
  }
  feature_path_for <- function(km, ds, seed, C, sigma = NULL) {
    sub <- if (km == "svmLinear") {
      sprintf("%s/%s/seed_%s/C_%s/features.rds", km, ds, seed, fmtC(C))
    } else {
      sprintf("%s/%s/seed_%s/sigma_%s/C_%s/features.rds", km, ds, seed, fmtSig(sigma), fmtC(C))
    }
    file.path(feature_base, sub)
  }
  save_artifacts <- function(model_obj, feat_spec, km, ds, seed, C, sigma = NULL) {
    mp <- model_path_for(km, ds, seed, C, sigma)
    fp <- feature_path_for(km, ds, seed, C, sigma)
    ensure_dir(mp); ensure_dir(fp)
    saveRDS(model_obj, file = mp)
    saveRDS(feat_spec, file = fp)
    message(sprintf("Saved -> %s  |  %s", mp, fp))
  }
  # Tokenise -> lowercase -> drop stopwords -> dense dfm matrix. Also return
  # the spec (vocabulary + flags) needed to rebuild features at inference.
  build_features_training <- function(text_vec, seed = 1, stop_lang = "de") {
    set.seed(seed)
    toks <- quanteda::tokens(text_vec,
                             remove_punct   = TRUE,
                             remove_numbers = TRUE,
                             remove_symbols = TRUE) |>
      quanteda::tokens_tolower() |>
      quanteda::tokens_remove(quanteda::stopwords(stop_lang))
    dfm_train <- quanteda::dfm(toks)
    x <- as.matrix(dfm_train)
    spec <- list(
      vocab         = quanteda::featnames(dfm_train),
      stop_lang     = stop_lang,
      tolower       = TRUE,
      remove_punct  = TRUE,
      remove_numbers= TRUE,
      remove_symbols= TRUE
    )
    list(x = x, spec = spec)
  }
  # ------------------------------------------------------------------------

  if (is.null(tuning_seed)) tuning_seed <- seeds[[1]]
  if (is.null(eval_seeds))  eval_seeds  <- utils::head(seeds, 3)

  # Sigma candidates for the RBF kernel. sigest() quantiles are computed on an
  # 80% sample of the document-feature matrix (zero-variance and constant
  # columns dropped), then each is scaled by 0.5, 1 and 2.
  get_sigma_candidates <- function(seed) {
    set.seed(seed)
    corpus_obj <- quanteda::corpus(df, text_field = "text")
    tokens_obj <- quanteda::tokens(corpus_obj,
                                   remove_punct = TRUE, remove_numbers = TRUE, remove_symbols = TRUE) |>
      quanteda::tokens_tolower() |>
      quanteda::tokens_remove(quanteda::stopwords("de"))
    dfm_obj <- quanteda::dfm(tokens_obj)
    tr_idx <- sample(seq_len(quanteda::ndoc(dfm_obj)), size = 0.80 * quanteda::ndoc(dfm_obj))
    X_train_tmp <- as.data.frame(as.matrix(dfm_obj[tr_idx, ]))
    if (ncol(X_train_tmp) > 0) {
      nzv <- caret::nearZeroVar(X_train_tmp, saveMetrics = TRUE)
      if (any(nzv$zeroVar)) X_train_tmp <- X_train_tmp[, !nzv$zeroVar, drop = FALSE]
      constant_columns <- vapply(X_train_tmp, function(col) length(unique(col)) == 1, logical(1))
      if (any(constant_columns)) X_train_tmp <- X_train_tmp[, !constant_columns, drop = FALSE]
    }
    sig <- kernlab::sigest(as.matrix(X_train_tmp), frac = 1)
    sort(unique(as.numeric(outer(as.numeric(sig), c(0.5, 1, 2), `*`))))
  }

  kernel_results <- list()
  for (km in kernels) {
    message("==== Kernel: ", km, " ====")
    is_linear <- km == "svmLinear"

    # -------- 1) TUNE ONCE (tuning_seed) ---------------------------------
    if (is_linear) {
      tune_grid <- expand.grid(C = C_grid)
    } else {
      sigma_values <- if (is.null(sigma_grid)) get_sigma_candidates(tuning_seed) else sigma_grid
      tune_grid <- expand.grid(sigma = sigma_values, C = C_grid)
    }

    tune_res <- train_and_evaluate_model_v2(
      df = df,
      label_column = label_column,
      label_levels = label_levels,
      model_method = km,
      tune_grid    = tune_grid,
      tune_length  = NULL,
      seed_val     = tuning_seed
    )
    best_tune <- as.data.frame(tune_res$best_tune)

    # Fail with a clear message rather than "subscript out of bounds" below.
    needed <- if (is_linear) "C" else c("sigma", "C")
    if (nrow(best_tune) == 0 || !all(needed %in% names(best_tune))) {
      stop("Tuning for ", km, " returned no usable best_tune (expected columns: ",
           paste(needed, collapse = ", "), ").")
    }

    # Reset per kernel so a value from a previous kernel can never leak in.
    C_star     <- best_tune$C[[1]]
    sigma_star <- if (is_linear) NULL else best_tune$sigma[[1]]
    fixed_grid <- if (is_linear) {
      data.frame(C = C_star)
    } else {
      data.frame(sigma = sigma_star, C = C_star)
    }

    # -------- 2) EVALUATE across eval_seeds ------------------------------
    per_seed <- lapply(eval_seeds, function(s) {
      r <- train_and_evaluate_model_v2(
        df = df,
        label_column = label_column,
        label_levels = label_levels,
        model_method = km,
        tune_grid    = fixed_grid,  # single-point eval at tuned params
        tune_length  = NULL,
        seed_val     = s
      )
      data.frame(seed = s,
                 macro_f1 = r$macro_f1,
                 C = C_star,
                 sigma = if (is_linear) NA_real_ else sigma_star,
                 stringsAsFactors = FALSE)
    })
    per_seed_df <- dplyr::bind_rows(per_seed)

    # (Per-class F1s across seeds could be collected from r$f1_scores likewise.)

    agg <- list(
      tuned_on_seed   = tuning_seed,
      fixed_params    = if (is_linear)
        list(C = C_star) else list(C = C_star, sigma = sigma_star),
      eval_seeds      = eval_seeds,
      avg_macro_f1    = mean(per_seed_df$macro_f1, na.rm = TRUE),
      sd_macro_f1     = stats::sd(per_seed_df$macro_f1, na.rm = TRUE)
    )

    # -------- 3) Optional: save fixed models for each eval seed ----------
    if (isTRUE(save_models)) {
      if (is.null(dataset_name) || !nzchar(dataset_name)) {
        stop("save_models=TRUE but dataset_name is NULL/empty. Provide dataset_name so paths can be built.")
      }
      y_all <- factor(df[[label_column]], levels = label_levels)
      ctrl  <- caret::trainControl(method = "none")

      for (s in eval_seeds) {
        feats <- build_features_training(df[["text"]], seed = s, stop_lang = "de")
        # Seed right before fitting so any RNG use inside the fit is
        # reproducible per seed.
        set.seed(s)
        # Refit the same caret method that was tuned, at the tuned setting.
        mdl <- caret::train(x = feats$x, y = y_all,
                            method    = km,
                            trControl = ctrl,
                            tuneGrid  = fixed_grid)
        save_artifacts(mdl, feats$spec, km, dataset_name, s, C = C_star, sigma = sigma_star)
      }
    }

    kernel_results[[km]] <- list(
      aggregate = agg,
      per_seed_macro = per_seed_df,
      best_from_tuning = best_tune
    )
  }

  kernel_results
}


################################################################################
################################################################################
################################################################################


# Flatten per-kernel results into one long data frame: one row per
# (kernel, class).
#
# Args:
#   res_list      list named by kernel. Each entry has avg_f1 / sd_f1
#                 (per-class vectors named "Class: <x>"), avg_macro, sd_macro
#                 and optionally final_best / per_seed_best.
#   dataset_name  value written to the `dataset` column.
#
# Returns: a data.frame with columns dataset, kernel, class, avg_F1_score,
#   sd_F1_score, avg_macro_F1, sd_macro_F1. The per-kernel `final_best` and
#   `per_seed_best` entries are attached as attributes of the same names.
#   An empty `res_list` gives an empty data.frame.
make_result_df <- function(res_list, dataset_name) {
  if (length(res_list) == 0) {
    return(data.frame())
  }

  kernel_names <- names(res_list)
  pieces <- vector("list", length(kernel_names))
  for (k in seq_along(kernel_names)) {
    kernel <- kernel_names[[k]]
    entry <- res_list[[kernel]]
    pieces[[k]] <- data.frame(
      dataset       = dataset_name,
      kernel        = kernel,
      # caret labels per-class metrics as "Class: <label>"; keep the label only
      class         = gsub("^Class:\\s*", "", names(entry$avg_f1)),
      avg_F1_score  = as.numeric(entry$avg_f1),
      sd_F1_score   = as.numeric(entry$sd_f1),
      avg_macro_F1  = entry$avg_macro,
      sd_macro_F1   = entry$sd_macro,
      stringsAsFactors = FALSE
    )
  }
  combined <- do.call(rbind, pieces)

  # Hyper-parameter summaries travel along as attributes.
  attr(combined, "final_best") <- lapply(res_list, function(entry) entry$final_best)
  attr(combined, "per_seed_best") <- lapply(res_list, function(entry) entry$per_seed_best)
  combined
}

################################################################################
################################################################################
################################################################################

# Build the caret tuning grid for a supported SVM kernel.
#
# Args:
#   kernel_method  "svmLinear" or "svmRadial"; anything else is an error.
#   C_values       cost values to try.
#   gammas         RBF sigma values (svmRadial only).
#
# Returns: expand.grid(C) for svmLinear and expand.grid(sigma, C) for
#   svmRadial. For svmRadial without `gammas` it returns NULL, so the caller
#   must supply sigma candidates itself.
build_kernel_grid <- function(kernel_method,
                              C_values,
                              gammas = NULL) {       # for svmRadial: sigma values
  if (kernel_method != "svmLinear" && kernel_method != "svmRadial") {
    stop("Only svmLinear and svmRadial are supported here.")
  }
  if (kernel_method == "svmLinear") {
    return(expand.grid(C = C_values))
  }
  # svmRadial from here on: no sigma candidates means no grid.
  if (is.null(gammas)) {
    return(NULL)
  }
  expand.grid(sigma = gammas, C = C_values)
}

# Return the training vocabulary stored in a feature spec. The spec is the
# list saved as features.rds at training time, e.g.
# list(vocab = <character>, stop_lang = "de", tolower = TRUE, ...).
#
# `vocab` is looked up by exact name. `obj$vocab` would partially match
# another element (e.g. `vocab_size`) and silently return the wrong thing.
#
# Args:
#   obj  the feature spec.
# Returns: the vocabulary as a character vector.
# Errors: when `obj` is not a list or has no `vocab` element.
get_feature_names <- function(obj) {
  vocab <- if (is.list(obj)) obj[["vocab"]] else NULL
  if (!is.null(vocab)) {
    return(as.character(vocab))
  }
  stop("Unexpected feature object; no $vocab found (class: ", paste(class(obj), collapse=", "), ").")
}

################################################################################
################################################################################
################################################################################

# Apply each saved model to an unlabeled data set, adding one prediction
# column per model.
#
# Args:
#   df_unlabeled   data frame with the texts to classify.
#   model_paths    paths to saved models (.rds), read with readRDS().
#   feature_paths  paths to the matching feature specs (.rds), in the same
#                  order as `model_paths`. Each spec is a list with `vocab`
#                  plus the preprocessing flags stop_lang / tolower /
#                  remove_punct / remove_numbers / remove_symbols.
#   text_field     column of `df_unlabeled` holding the text.
#   docid_field    column of `df_unlabeled` used as quanteda document id.
#
# For each model:
#   * The texts are tokenised with the flags from its spec.
#   * The dfm is aligned column-for-column to the training vocabulary with
#     quanteda::dfm_match().
#   * Class labels are predicted with predict(type = "raw"). A failing
#     predict() becomes a warning and a column of NA.
#
# Columns are named "pred_label_seed<seed>". <seed> comes from a
# "seed_<n>" directory in the path, i.e. the .../seed_<n>/.../model.rds
# layout used when saving models. If there is none, it falls back to the
# digits in the file name. If two models map to the same column name, the
# later one gets "_<i>" appended (with a warning) instead of overwriting.
#
# Returns: `df_unlabeled` with the prediction columns added.
predict_unlabeled_all_seeds <- function(df_unlabeled, 
                                        model_paths, 
                                        feature_paths, 
                                        text_field = "text",
                                        docid_field = "unique_id") {
  if (length(model_paths) != length(feature_paths)) {
    stop("model_paths (", length(model_paths), ") and feature_paths (",
         length(feature_paths), ") must have the same length.")
  }

  # Seed label for a model path: prefer a "seed_<n>" directory (the file
  # itself is just "model.rds"), else fall back to digits in the file name.
  seed_label_for <- function(path) {
    parts <- strsplit(path, "[/\\\\]")[[1]]
    seed_dirs <- grep("^seed_[0-9]+$", parts, value = TRUE)
    if (length(seed_dirs) > 0) {
      return(sub("^seed_", "", seed_dirs[[length(seed_dirs)]]))
    }
    gsub("\\D+", "", basename(path))
  }

  used_cols <- character(0)

  for (i in seq_along(model_paths)) {
    model_path <- model_paths[[i]]
    cat("Applying model:", model_path, "\n")

    svm_model <- readRDS(model_path)
    tf        <- readRDS(feature_paths[[i]])   # list with vocab + flags

    # --- mirror training preprocessing flags ---
    stop_lang     <- if (!is.null(tf$stop_lang)) tf$stop_lang else "de"
    tolower_flag  <- isTRUE(tf$tolower)
    rm_punct      <- isTRUE(tf$remove_punct)
    rm_numbers    <- isTRUE(tf$remove_numbers)
    rm_symbols    <- isTRUE(tf$remove_symbols)

    corpus_unlab <- quanteda::corpus(df_unlabeled, text_field = text_field, docid_field = docid_field)
    toks <- quanteda::tokens(corpus_unlab,
                             remove_punct   = rm_punct,
                             remove_numbers = rm_numbers,
                             remove_symbols = rm_symbols)
    if (tolower_flag) toks <- quanteda::tokens_tolower(toks)
    toks <- quanteda::tokens_remove(toks, quanteda::stopwords(stop_lang))

    dfm_unlab <- quanteda::dfm(toks)

    # --- use the actual vocabulary stored in the feature spec ---
    feats_train <- get_feature_names(tf)
    if (length(feats_train) == 0) {
      stop("Empty training vocabulary for: ", basename(model_path))
    }

    # Fail fast when nothing overlaps: predictions would be meaningless.
    overlap <- intersect(quanteda::featnames(dfm_unlab), feats_train)
    if (length(overlap) == 0L) {
      stop("No feature overlap with training vocab for ", basename(model_path),
           ". Preprocessing mismatch or wrong stop_lang.")
    }

    # Same columns, same order as at training time (missing terms become 0).
    dfm_matched <- quanteda::dfm_match(dfm_unlab, features = feats_train)
    X <- as.matrix(dfm_matched)
    storage.mode(X) <- "double"

    # Hard class labels; a failure is reported and yields NA labels.
    pred_labels <- tryCatch(
      as.character(predict(svm_model, newdata = X, type = "raw")),
      error = function(e) {
        warning("Prediction failed: ", conditionMessage(e), " — returning NA.")
        rep(NA_character_, nrow(X))
      }
    )

    if (length(pred_labels) != nrow(df_unlabeled)) {
      stop("Prediction length (", length(pred_labels), 
           ") != data rows (", nrow(df_unlabeled), ").")
    }

    seed_col_name <- paste0("pred_label_seed", seed_label_for(model_path))
    if (seed_col_name %in% used_cols) {
      unique_name <- paste0(seed_col_name, "_", i)
      warning("Column ", seed_col_name, " was already written by an earlier model; ",
              "storing predictions of ", model_path, " in ", unique_name, ".")
      seed_col_name <- unique_name
    }
    used_cols <- c(used_cols, seed_col_name)
    df_unlabeled[[seed_col_name]] <- pred_labels
  }

  df_unlabeled
}


# Run train_and_evaluate_model_v2() for every (kernel, seed) pair, tuning
# separately on each seed, and aggregate the results per kernel.
#
# Tuning grids come from build_kernel_grid(). For svmRadial without
# `sigma_grid`, sigma candidates are estimated per seed with
# kernlab::sigest().
#
# Per-class F1 and macro-F1 are averaged (mean / sd) over seeds. The best
# hyper-parameters of each seed are collected in `per_seed_best`.
# `final_best` is the setting with the highest mean macro-F1 among the
# settings that some seed selected.
#
# If `save_models` is TRUE, a model is refitted on all rows of `df` at
# `final_best` for every seed and saved together with its feature spec under
# `model_base` / `feature_base`.
#
# Returns: a list named by kernel. Each entry holds avg_f1, sd_f1, avg_macro,
#   sd_macro, per_seed_best and final_best.
train_and_evaluate_model_multiple_seeds_kernels <- function(df,
                                                            label_column,
                                                            label_levels,
                                                            seeds = c(42, 123, 999, 2024, 2025),
                                                            kernels = c("svmLinear","svmRadial"),
                                                            C_grid = 2^seq(-5, 5, by = 1),
                                                            sigma_grid = NULL,   # NULL = auto via sigest
                                                            # --- saving knobs ---
                                                            save_models   = FALSE,
                                                            dataset_name  = NULL,           # e.g., "strong_balanced"
                                                            model_base    = "models/models",
                                                            feature_base  = "models/features") {
  # -------- helpers (paths + saving) ------------------------------------
  ensure_dir <- function(path) dir.create(dirname(path), recursive = TRUE, showWarnings = FALSE)
  fmtC   <- function(x) format(x, scientific = FALSE, trim = TRUE)
  fmtSig <- function(x) format(x, scientific = TRUE, trim = TRUE)
  model_path_for <- function(km, ds, seed, C, sigma = NULL) {
    sub <- if (km == "svmLinear") {
      sprintf("%s/%s/seed_%s/C_%s/model.rds", km, ds, seed, fmtC(C))
    } else {
      sprintf("%s/%s/seed_%s/sigma_%s/C_%s/model.rds", km, ds, seed, fmtSig(sigma), fmtC(C))
    }
    file.path(model_base, sub)
  }
  feature_path_for <- function(km, ds, seed, C, sigma = NULL) {
    sub <- if (km == "svmLinear") {
      sprintf("%s/%s/seed_%s/C_%s/features.rds", km, ds, seed, fmtC(C))
    } else {
      sprintf("%s/%s/seed_%s/sigma_%s/C_%s/features.rds", km, ds, seed, fmtSig(sigma), fmtC(C))
    }
    file.path(feature_base, sub)
  }
  save_artifacts <- function(model_obj, feat_spec, km, ds, seed, C, sigma = NULL) {
    mp <- model_path_for(km, ds, seed, C, sigma)
    fp <- feature_path_for(km, ds, seed, C, sigma)
    ensure_dir(mp); ensure_dir(fp)
    saveRDS(model_obj, file = mp)
    saveRDS(feat_spec, file = fp)  # always write a spec (never empty dir)
    message(sprintf("Saved -> %s  |  %s", mp, fp))
  }

  # Build features (training) and return both the matrix and a spec we can reuse at inference.
  build_features_training <- function(text_vec, seed = 1, stop_lang = "de") {
    set.seed(seed)
    toks <- quanteda::tokens(text_vec,
                             remove_punct   = TRUE,
                             remove_numbers = TRUE,
                             remove_symbols = TRUE) |>
      quanteda::tokens_tolower() |>
      quanteda::tokens_remove(quanteda::stopwords(stop_lang))
    dfm_train <- quanteda::dfm(toks)
    x <- as.matrix(dfm_train)
    spec <- list(
      vocab         = quanteda::featnames(dfm_train),
      stop_lang     = stop_lang,
      tolower       = TRUE,
      remove_punct  = TRUE,
      remove_numbers= TRUE,
      remove_symbols= TRUE
    )
    list(x = x, spec = spec)
  }

  # Sigma candidates for svmRadial. sigest() quantiles are computed on an 80%
  # sample of the document-feature matrix (zero-variance and constant columns
  # dropped), then each is scaled by 0.5, 1 and 2.
  estimate_sigma_candidates <- function(seed) {
    set.seed(seed)
    corpus_obj <- quanteda::corpus(df, text_field = "text")
    tokens_obj <- quanteda::tokens(corpus_obj,
                                   remove_punct = TRUE, remove_numbers = TRUE, remove_symbols = TRUE) |>
      quanteda::tokens_tolower() |>
      quanteda::tokens_remove(quanteda::stopwords("de"))
    dfm_obj <- quanteda::dfm(tokens_obj)
    tr_idx <- sample(seq_len(quanteda::ndoc(dfm_obj)), size = 0.80 * quanteda::ndoc(dfm_obj))
    X_train_tmp <- as.data.frame(as.matrix(dfm_obj[tr_idx, ]))
    if (ncol(X_train_tmp) > 0) {
      nzv <- caret::nearZeroVar(X_train_tmp, saveMetrics = TRUE)
      if (any(nzv$zeroVar)) X_train_tmp <- X_train_tmp[, !nzv$zeroVar, drop = FALSE]
      constant_columns <- vapply(X_train_tmp, function(col) length(unique(col)) == 1, logical(1))
      if (any(constant_columns)) X_train_tmp <- X_train_tmp[, !constant_columns, drop = FALSE]
    }
    sig <- kernlab::sigest(as.matrix(X_train_tmp), frac = 1)
    sort(unique(as.numeric(outer(as.numeric(sig), c(0.5, 1, 2), `*`))))
  }
  # ----------------------------------------------------------------------

  kernel_results <- list()

  for (km in kernels) {
    message("==== Kernel: ", km, " ====")
    base_grid <- build_kernel_grid(km, C_values = C_grid, gammas = sigma_grid)

    n_seeds   <- length(seeds)
    all_f1    <- vector("list", n_seeds)
    all_macro <- rep(NA_real_, n_seeds)
    best_list <- vector("list", n_seeds)

    for (i in seq_along(seeds)) {
      s <- seeds[i]
      tg <- base_grid

      # Auto-sigma for RBF if not provided
      if (km == "svmRadial" && (is.null(tg) || !("sigma" %in% names(tg)))) {
        tg <- expand.grid(sigma = estimate_sigma_candidates(s), C = C_grid)
      }

      r <- train_and_evaluate_model_v2(
        df = df,
        label_column = label_column,
        label_levels = label_levels,
        model_method = km,
        tune_grid = tg,
        tune_length = if (is.null(tg)) 3 else NULL,
        seed_val = s
      )
      # `x[[i]] <- NULL` would delete slot i and shift every later seed, so
      # wrap the results in list() to keep each slot aligned with its seed.
      all_f1[i]    <- list(r$f1_scores)
      best_list[i] <- list(r$best_tune)
      if (!is.null(r$macro_f1)) all_macro[i] <- r$macro_f1
    }

    # aggregate (rbind skips seeds that returned no F1 vector)
    mat <- do.call(rbind, all_f1)
    agg <- list(
      avg_f1    = colMeans(mat, na.rm = TRUE),
      sd_f1     = apply(mat, 2, sd, na.rm = TRUE),
      avg_macro = mean(all_macro, na.rm = TRUE),
      sd_macro  = sd(all_macro, na.rm = TRUE)
    )

    # summarize best across seeds (seeds without a best_tune are skipped)
    best_df <- dplyr::bind_rows(lapply(seq_along(seeds), function(i) {
      bt <- as.data.frame(best_list[[i]])
      if (nrow(bt) == 0) return(NULL)
      bt$seed <- seeds[i]
      bt$macro_f1 <- all_macro[i]
      bt
    }))

    # Hyper-parameter columns that identify a setting for this kernel.
    param_cols <- if (km == "svmLinear") "C" else c("sigma", "C")
    final_best <- best_df |>
      dplyr::group_by(dplyr::across(dplyr::all_of(param_cols))) |>
      dplyr::summarise(avg_macro_f1 = mean(macro_f1, na.rm = TRUE), .groups = "drop") |>
      dplyr::arrange(dplyr::desc(avg_macro_f1)) |>
      dplyr::slice(1)

    # --------- SAVE: refit exactly at final_best and persist -----------
    if (isTRUE(save_models)) {
      if (is.null(dataset_name) || !nzchar(dataset_name)) {
        stop("save_models=TRUE but dataset_name is NULL/empty. Provide dataset_name so paths can be built.")
      }
      if (nrow(final_best) == 0) {
        stop("No best hyper-parameters available for ", km, "; nothing to save.")
      }
      # Prepare label factor once
      y_all <- factor(df[[label_column]], levels = label_levels)

      C_star <- final_best$C[[1]]
      if (km == "svmLinear") {
        sigma_star <- NULL
        tune_fixed <- data.frame(C = C_star)
      } else {
        sigma_star <- final_best$sigma[[1]]
        tune_fixed <- data.frame(sigma = sigma_star, C = C_star)
      }
      ctrl <- caret::trainControl(method = "none")

      for (s in seeds) {
        # df rows are not reordered, so y_all stays aligned with feats$x.
        feats <- build_features_training(df[["text"]], seed = s, stop_lang = "de")
        set.seed(s)
        mdl <- caret::train(x = feats$x, y = y_all,
                            method    = km,
                            trControl = ctrl,
                            tuneGrid  = tune_fixed)
        save_artifacts(mdl, feats$spec, km, dataset_name, s, C = C_star, sigma = sigma_star)
      }
    }
    # -------------------------------------------------------------------

    kernel_results[[km]] <- c(agg, list(
      per_seed_best = best_df,
      final_best = final_best
    ))
  }

  kernel_results
}

